{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "-ETtu9CvVMDR"
   },
   "source": [
    "<h1>Chapter 10 - Creating Text Embedding Models</h1>\n",
    "<i>Exploring methods for both training and fine-tuning embedding models.</i>\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\"><img src=\"https://img.shields.io/badge/Buy%20the%20Book!-grey?logo=amazon\"></a>\n",
    "<a href=\"https://www.oreilly.com/library/view/hands-on-large-language/9781098150952/\"><img src=\"https://img.shields.io/badge/O'Reilly-white.svg?logo=\"></a>\n",
    "<a href=\"https://github.com/HandsOnLLM/Hands-On-Large-Language-Models\"><img src=\"https://img.shields.io/badge/GitHub%20Repository-black?logo=github\"></a>\n",
    "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/HandsOnLLM/Hands-On-Large-Language-Models/blob/main/chapter10/Chapter%2010%20-%20Creating%20Text%20Embedding%20Models.ipynb)\n",
    "\n",
    "---\n",
    "\n",
    "This notebook is for Chapter 10 of the [Hands-On Large Language Models](https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961) book by [Jay Alammar](https://www.linkedin.com/in/jalammar) and [Maarten Grootendorst](https://www.linkedin.com/in/mgrootendorst/).\n",
    "\n",
    "---\n",
    "\n",
    "<a href=\"https://www.amazon.com/Hands-Large-Language-Models-Understanding/dp/1098150961\">\n",
    "<img src=\"https://raw.githubusercontent.com/HandsOnLLM/Hands-On-Large-Language-Models/main/images/book_cover.png\" width=\"350\"/></a>\n"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### [OPTIONAL] - Installing Packages on <img src=\"https://colab.google/static/images/icons/colab.png\" width=100>\n",
    "\n",
    "If you are viewing this notebook on Google Colab (or any other cloud vendor), you need to **uncomment and run** the following code block to install the dependencies for this chapter:\n",
    "\n",
    "---\n",
    "\n",
    "💡 **NOTE**: You will need a GPU to run the examples in this notebook. In Google Colab, go to\n",
    "**Runtime > Change runtime type > Hardware accelerator > GPU > GPU type > T4**.\n",
    "\n",
    "---\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# %%capture\n",
    "# !pip install -q \"accelerate>=0.27.2\" \"peft>=0.9.0\" \"bitsandbytes>=0.43.0\" \"transformers>=4.38.2\" \"trl>=0.7.11\" \"sentencepiece>=0.1.99\"\n",
    "# !pip install -q \"sentence-transformers>=3.0.0\" \"mteb>=1.1.2\" \"datasets>=2.18.0\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "2UrKluX5YNmu"
   },
   "source": [
    "# Creating an Embedding Model"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ywsyZzm5VSER"
   },
   "source": [
    "## **Data**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 6529,
     "status": "ok",
     "timestamp": 1717342944433,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "Ahk0SJDKVy6F",
    "outputId": "497309ee-333a-4a6c-f008-dd6262a7a52f"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/utils/_token.py:89: UserWarning: \n",
      "The secret `HF_TOKEN` does not exist in your Colab secrets.\n",
      "To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.\n",
      "You will be able to reuse this secret in all of your notebooks.\n",
      "Please note that authentication is recommended but still optional to access public models or datasets.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from datasets import load_dataset\n",
    "\n",
    "# Load MNLI dataset from GLUE\n",
    "# 0 = entailment, 1 = neutral, 2 = contradiction\n",
    "train_dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n",
    "train_dataset = train_dataset.remove_columns(\"idx\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3,
     "status": "ok",
     "timestamp": 1717342944434,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "t-BHO4-qwMDO",
    "outputId": "f6671b92-7319-48bb-96a3-c848b45dee33"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'premise': 'One of our number will carry out your instructions minutely.',\n",
       " 'hypothesis': 'A member of my team will execute your orders with immense precision.',\n",
       " 'label': 0}"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train_dataset[2]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "5wO23cXLXeFU"
   },
   "source": [
    "## **Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 19919,
     "status": "ok",
     "timestamp": 1717342964351,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "C4qLaPR6nrqC",
    "outputId": "76fa9f0a-9c99-4e67-be82-70f9c41ba1b8"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name bert-base-uncased. Creating a new one with mean pooling.\n",
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    }
   ],
   "source": [
    "from sentence_transformers import SentenceTransformer\n",
    "\n",
    "# Use a base model; sentence-transformers wraps it with a mean-pooling layer\n",
    "embedding_model = SentenceTransformer('bert-base-uncased')"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "pAiL21AuYKVI"
   },
   "source": [
    "## **Loss Function**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OgmtKckBXiK9"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers import losses\n",
    "\n",
    "# Define the loss function. With softmax loss, we also need to explicitly set the number of labels (3 for MNLI).\n",
    "train_loss = losses.SoftmaxLoss(\n",
    "    model=embedding_model,\n",
    "    sentence_embedding_dimension=embedding_model.get_sentence_embedding_dimension(),\n",
    "    num_labels=3\n",
    ")"
   ]
  },
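  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The softmax loss frames training as a three-way classification over the NLI labels. The toy NumPy cell below sketches the idea, assuming the default behavior of `losses.SoftmaxLoss`: the premise embedding `u` and hypothesis embedding `v` are combined as `(u, v, |u - v|)` and fed to a linear classifier trained with cross-entropy. The embeddings and weights are random placeholders, so this is an illustration of the technique, not the library's implementation or the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "rng = np.random.default_rng(0)\n",
    "dim, num_labels = 8, 3  # toy sizes; bert-base-uncased produces 768-dim embeddings\n",
    "\n",
    "# Toy embeddings for a premise (u) and a hypothesis (v)\n",
    "u = rng.normal(size=dim)\n",
    "v = rng.normal(size=dim)\n",
    "\n",
    "# Concatenate (u, v, |u - v|) into a single feature vector of size 3 * dim\n",
    "features = np.concatenate([u, v, np.abs(u - v)])\n",
    "\n",
    "# Randomly initialized linear classifier (bias omitted for brevity)\n",
    "W = rng.normal(size=(num_labels, 3 * dim)) * 0.1\n",
    "logits = W @ features\n",
    "\n",
    "# Numerically stable softmax, then cross-entropy for the gold label\n",
    "probs = np.exp(logits - logits.max())\n",
    "probs /= probs.sum()\n",
    "label = 0  # 0 = entailment\n",
    "loss = -np.log(probs[label])\n",
    "print(features.shape, probs.round(3), round(float(loss), 4))"
   ]
  },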
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tH0efspwlOX2"
   },
   "source": [
    "## **Evaluation**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "f8ZsoY0AretV"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Create an embedding similarity evaluator for STSB; gold scores (0-5) are rescaled to 0-1\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],\n",
    "    main_similarity=\"cosine\",\n",
    ")"
   ]
  },
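  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Conceptually, the evaluator embeds both sentences of every STSB pair, computes their cosine similarity, and reports how well those similarities correlate with the gold scores (Pearson and Spearman). The NumPy cell below is a minimal sketch of that idea with made-up embeddings and scores rather than the real STSB data; the helper names `cosine` and `ranks` are hypothetical, and the `ranks` helper assumes there are no ties."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine(a, b):\n",
    "    # Row-wise cosine similarity between two matrices of embeddings\n",
    "    return (a * b).sum(axis=1) / (np.linalg.norm(a, axis=1) * np.linalg.norm(b, axis=1))\n",
    "\n",
    "def ranks(x):\n",
    "    # Rank positions (0 = smallest); assumes no ties for simplicity\n",
    "    r = np.empty(len(x))\n",
    "    r[np.argsort(x)] = np.arange(len(x))\n",
    "    return r\n",
    "\n",
    "# Toy sentence pairs: the second embedding drifts further from the first in each row\n",
    "rng = np.random.default_rng(42)\n",
    "emb1 = rng.normal(size=(5, 4))\n",
    "emb2 = emb1 + rng.normal(scale=[[0.1], [0.5], [1.0], [2.0], [4.0]], size=(5, 4))\n",
    "\n",
    "gold = np.array([4.8, 3.9, 2.5, 1.2, 0.3]) / 5  # rescale 0-5 scores to 0-1\n",
    "pred = cosine(emb1, emb2)\n",
    "\n",
    "# Pearson on raw values, Spearman as Pearson on ranks\n",
    "pearson = np.corrcoef(pred, gold)[0, 1]\n",
    "spearman = np.corrcoef(ranks(pred), ranks(gold))[0, 1]\n",
    "print(round(float(pearson), 3), round(float(spearman), 3))"
   ]
  },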
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "umikSmoYIP07"
   },
   "source": [
    "## **Training**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8uAAhNs0ocoV"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"base_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    # Note: eval_steps only takes effect if an evaluation strategy such as 'steps' is enabled\n",
    "    eval_steps=100,\n",
    "    logging_steps=100,\n",
    ")"
   ]
  },
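  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "As a quick sanity check on these arguments: with 50,000 training pairs and `per_device_train_batch_size=32`, one epoch takes ⌈50,000 / 32⌉ = 1,563 optimizer steps (the last batch is partial), and with `logging_steps=100` the training loss is logged 15 times. The arithmetic below assumes a single device and no gradient accumulation."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import math\n",
    "\n",
    "num_examples = 50_000  # rows selected from MNLI\n",
    "batch_size = 32        # per_device_train_batch_size (single device assumed)\n",
    "logging_steps = 100\n",
    "\n",
    "steps_per_epoch = math.ceil(num_examples / batch_size)  # the last batch is partial\n",
    "num_log_points = steps_per_epoch // logging_steps\n",
    "\n",
    "print(steps_per_epoch, num_log_points)  # 1563 15"
   ]
  },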
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 618,
     "referenced_widgets": [
      "459f20375cf24b51bd89cf541e2b63fa",
      "c27d0f3225924807bc9f0dc18b926424",
      "739d1e36723d4ee09430c2bb98fd7766",
      "3f2b25e57f3a4bf69b18ba02d8128092",
      "56245fdec01446f38b956da92a296306",
      "5d471644fab242568d37184e9bbcb086",
      "ed4c445ced6a4ad3bd5b8e5a56ea23fb",
      "d26a9fc85ca849b99fbead500a295f4a",
      "e11b9173c7e34062be45e43737aa9eb5",
      "049ddce06bfc4f1e8c81d412814a4e74",
      "b4eb0759b0874712a9736da2c8203689",
      "3bad6bb0058f45e1aa59fc4fa44e5a6f",
      "f0756c62e13741f9b574e6a30c5e9aae",
      "61131cb0ea4e4d83a9a2debc6557435d",
      "bada4e4dedfa46d09cf544dc2647801a",
      "ca673a5e949f4eb79ce87c8b8aac954c",
      "1429dde1f00e4c74a6680739ed6355ee",
      "1428f068da84478faa6ad723744f78ec",
      "f347198f211849bd8cacef1b9094109d",
      "840dd9e2a63147ea8ce9c1a5423a2451",
      "7b8f7b0d76e9419a8677510d65f7dc1b",
      "b7a8360bf9204c5a9b7c0b87555cbdd3",
      "5446d5dc6ce047eca1b11b91bae71c81",
      "7d2fa3fb9bb143f3b5f31f74276651dd",
      "a11c52e4d7dd4c39995561c495ae5fa4",
      "98e9f40ea3274ef5bd4593e2cee97242",
      "cb0646d4e66246669a31359ca3e6476a",
      "3334b517bae2434588cf1fac55a57182",
      "c2affd1e287640a3967ef68344e86247",
      "406f19fb77b1492691851a0222bf5a28",
      "80ff5c0e8c3148fdaa1a0a9be16f812a",
      "e426aaf2f15d46b88d9ffb7855744045",
      "ee32bcff9df042a4a9cff5437ef61a46"
     ]
    },
    "executionInfo": {
     "elapsed": 374122,
     "status": "ok",
     "timestamp": 1717343342445,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "JKA_L39FpAoM",
    "outputId": "8091a7c2-585f-4668-9f83-3ca1a06acd79"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1563' max='1563' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1563/1563 06:10, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>1.080700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.959400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.916200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.870200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.849100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.854200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.835200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.825200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.818100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.800300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.781600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.777100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.786600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.767900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.797100</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "459f20375cf24b51bd89cf541e2b63fa",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3bad6bb0058f45e1aa59fc4fa44e5a6f",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5446d5dc6ce047eca1b11b91bae71c81",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1563, training_loss=0.8453957184872716, metrics={'train_runtime': 372.5713, 'train_samples_per_second': 134.202, 'train_steps_per_second': 4.195, 'total_flos': 0.0, 'train_loss': 0.8453957184872716, 'epoch': 1.0})"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "\n",
    "# Train embedding model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2838,
     "status": "ok",
     "timestamp": 1717343345280,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "_NA16lEaseOq",
    "outputId": "b5d86d22-480d-4c5c-efd1-8e80c03dab4d"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pearson_cosine': 0.3710938716460552,\n",
       " 'spearman_cosine': 0.45148122260403883,\n",
       " 'pearson_manhattan': 0.4037396904694362,\n",
       " 'spearman_manhattan': 0.4396893995197567,\n",
       " 'pearson_euclidean': 0.390788259199341,\n",
       " 'spearman_euclidean': 0.43444104358464286,\n",
       " 'pearson_dot': 0.3392927926047231,\n",
       " 'spearman_dot': 0.3530708415227247,\n",
       " 'pearson_max': 0.4037396904694362,\n",
       " 'spearman_max': 0.45148122260403883}"
      ]
     },
     "execution_count": 9,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "M9xjkvCWwrp_"
   },
   "source": [
    "# MTEB (Massive Text Embedding Benchmark)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 316
    },
    "executionInfo": {
     "elapsed": 33626,
     "status": "ok",
     "timestamp": 1717343378904,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "7Hueu4upxVYb",
    "outputId": "c3fe0828-a480-49ff-efe8-fe42b5b46db3"
   },
   "outputs": [
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"color: #262626; text-decoration-color: #262626\">───────────────────────────────────────────────── </span><span style=\"font-weight: bold\">Selected tasks </span><span style=\"color: #262626; text-decoration-color: #262626\"> ─────────────────────────────────────────────────</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[38;5;235m───────────────────────────────────────────────── \u001b[0m\u001b[1mSelected tasks \u001b[0m\u001b[38;5;235m ─────────────────────────────────────────────────\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\"><span style=\"font-weight: bold\">Classification</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\u001b[1mClassification\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">    - Banking77Classification, <span style=\"color: #626262; text-decoration-color: #626262; font-style: italic\">s2s</span>\n",
       "</pre>\n"
      ],
      "text/plain": [
       "    - Banking77Classification, \u001b[3;38;5;241ms2s\u001b[0m\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "<pre style=\"white-space:pre;overflow-x:auto;line-height:normal;font-family:Menlo,'DejaVu Sans Mono',consolas,'Courier New',monospace\">\n",
       "\n",
       "</pre>\n"
      ],
      "text/plain": [
       "\n",
       "\n"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/joblib/externals/loky/backend/fork_exec.py:38: RuntimeWarning: os.fork() was called. os.fork() is incompatible with multithreaded code, and JAX is multithreaded, so this will likely lead to a deadlock.\n",
      "  pid = os.fork()\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'Banking77Classification': {'mteb_version': '1.1.2',\n",
       "  'dataset_revision': '0fd18e25b25c072e09e0d92ab615fda904d66300',\n",
       "  'mteb_dataset_name': 'Banking77Classification',\n",
       "  'test': {'accuracy': 0.46022727272727276,\n",
       "   'f1': 0.45802738001849663,\n",
       "   'accuracy_stderr': 0.009556987908238961,\n",
       "   'f1_stderr': 0.01072225943077292,\n",
       "   'main_score': 0.46022727272727276,\n",
       "   'evaluation_time': 29.63}}}"
      ]
     },
     "execution_count": 10,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from mteb import MTEB\n",
    "\n",
    "# Choose evaluation task\n",
    "evaluation = MTEB(tasks=[\"Banking77Classification\"])\n",
    "\n",
    "# Calculate results\n",
    "results = evaluation.run(embedding_model)\n",
    "results"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "56V2ma89uJwN"
   },
   "source": [
    "⚠️ **VRAM Clean-up** - You will need to run the code below to partially empty the VRAM (GPU RAM). If that does not work, restart the notebook instead. If you are using Google Colab, you can check the resources panel on the right-hand side to confirm that VRAM usage has dropped. You can also run `!nvidia-smi` to check current usage."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "c3LX1G0_4QCv"
   },
   "outputs": [],
   "source": [
    "# # Empty and delete trainer/model\n",
    "# trainer.accelerator.clear()\n",
    "# del trainer, embedding_model\n",
    "\n",
    "# # Garbage collection and empty cache\n",
    "# import gc\n",
    "# import torch\n",
    "\n",
    "# gc.collect()\n",
    "# torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6d0GcY8cnNs4"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "jYnRRSDN06eB"
   },
   "source": [
    "# Loss Functions"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "vuSCWbFO7RRM"
   },
   "source": [
    "⚠️ **VRAM Clean-up**\n",
    "* `Restart` the notebook to clean up memory before moving on to the next training example."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Tq8Yb6IB2LFI"
   },
   "source": [
    "## Cosine Similarity Loss"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "qEmnjQQPuszQ"
   },
   "outputs": [],
   "source": [
    "from datasets import Dataset, load_dataset\n",
    "\n",
    "# Load MNLI dataset from GLUE\n",
    "# 0 = entailment, 1 = neutral, 2 = contradiction\n",
    "train_dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n",
    "train_dataset = train_dataset.remove_columns(\"idx\")\n",
    "\n",
    "# (neutral/contradiction)=0 and (entailment)=1\n",
    "mapping = {2: 0, 1: 0, 0: 1}\n",
    "train_dataset = Dataset.from_dict({\n",
    "    \"sentence1\": train_dataset[\"premise\"],\n",
    "    \"sentence2\": train_dataset[\"hypothesis\"],\n",
    "    \"label\": [float(mapping[label]) for label in train_dataset[\"label\"]]\n",
    "})"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "np5bMwgO5y8g"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Create an embedding similarity evaluator for STSB; gold scores (0-5) are rescaled to 0-1\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],\n",
    "    main_similarity=\"cosine\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 692,
     "referenced_widgets": [
      "c16cc9c062f547ee938537a719d4f668",
      "94e4f08e49bc4043917ce262cfd17fb1",
      "afc5056a4b1a49579fc52faed7ce12b3",
      "cbb2ad7701bb4324907b30ac30c058bb",
      "f21767a805b24d3fa93c647517d1c1ee",
      "8478533cc376477e97a80376e0fb3678",
      "e70ed406b6c24fd2b22d518bcff8ff0c",
      "74b7ff4db1c64725b2bdb67e212bfac3",
      "3071516aab1a407eb4224e98caf675c6",
      "319644a2e9204cd08017353e867779f4",
      "97a00463893941be8e9685a990d2ad6a",
      "055c320e1c4540fb9f2b293dd28ed569",
      "6d194ebdd6334b998be781281a49fea3",
      "7ac2fbaa7741499cb0423c3b509f1656",
      "27f1913c7cea41858091d98eb8f1ea9b",
      "0fdc301b86594a7a81731fbecfa5f2ef",
      "0264b9ef5adf43cb8c3d4516b56e9a03",
      "2d9423c5a1594b4c853ebd363f14d917",
      "d31dd92c800246798b73035c1f81deeb",
      "c96b9494c274423bb7c869b16b4b3be1",
      "58bff1c2f9bf49f99d5d1715d1aaee01",
      "60d91c4725314ae29ad0019985e7c73a",
      "677333dacf8a41d08e0b7fc605da27f9",
      "cad3e9f01c98445e8fee42c396653813",
      "4eccf40bc6684d22ab26edf01290d2ae",
      "7868fb9e79ca4925afca2223088c1aa3",
      "f3ba5c1671304318977ce8f6c1b47391",
      "78e6d195ff224c66ae58ac6dad70430c",
      "ccb9de8bcfce439aaadf3732e1b2a5f8",
      "979b54240a164b70942340c79ab53777",
      "f7eeb347d6224eb48bbbc5879dbc80ae",
      "8e25d420a21b4649ac1ed6b4656098b3",
      "ccbe3bb6aabd4fbdaedad6755886560c"
     ]
    },
    "executionInfo": {
     "elapsed": 366439,
     "status": "ok",
     "timestamp": 1717343750576,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "Ikky866vdseY",
    "outputId": "803f2abe-002a-481c-9403-8ce39dfa5b47"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name bert-base-uncased. Creating a new one with mean pooling.\n",
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1563' max='1563' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1563/1563 06:04, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.231900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.168900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.170900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.157800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.152900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.156100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.149300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.154500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.150900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.145600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.147800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.145600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.145100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.142000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.141600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c16cc9c062f547ee938537a719d4f668",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "055c320e1c4540fb9f2b293dd28ed569",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "677333dacf8a41d08e0b7fc605da27f9",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1563, training_loss=0.15676780793427353, metrics={'train_runtime': 364.4779, 'train_samples_per_second': 137.183, 'train_steps_per_second': 4.288, 'total_flos': 0.0, 'train_loss': 0.15676780793427353, 'epoch': 1.0})"
      ]
     },
     "execution_count": 15,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import losses, SentenceTransformer\n",
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define model\n",
    "embedding_model = SentenceTransformer('bert-base-uncased')\n",
    "\n",
    "# Loss function\n",
    "train_loss = losses.CosineSimilarityLoss(model=embedding_model)\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"cosineloss_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    # Note: eval_steps only takes effect if an evaluation strategy such as 'steps' is enabled\n",
    "    eval_steps=100,\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
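  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`CosineSimilarityLoss` regresses the cosine similarity of each embedding pair onto its label, here 1.0 for entailment and 0.0 for neutral or contradiction. Assuming the library's default mean squared error objective, the per-batch loss can be sketched in NumPy as follows; `cosine_similarity_loss` is a hypothetical helper operating on toy 2-D embeddings, not the trained model."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "def cosine_similarity_loss(u, v, labels):\n",
    "    # Cosine similarity per pair, then mean squared error against the labels\n",
    "    cos = (u * v).sum(axis=1) / (np.linalg.norm(u, axis=1) * np.linalg.norm(v, axis=1))\n",
    "    return float(np.mean((cos - labels) ** 2))\n",
    "\n",
    "u = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
    "v = np.array([[1.0, 0.0], [0.0, 1.0]])  # first pair identical, second orthogonal\n",
    "labels = np.array([1.0, 0.0])           # entailment, non-entailment\n",
    "\n",
    "print(cosine_similarity_loss(u, v, labels))      # 0.0: both cosines match their labels\n",
    "print(cosine_similarity_loss(u, v, 1 - labels))  # 1.0: both cosines are maximally wrong"
   ]
  },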
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3546,
     "status": "ok",
     "timestamp": 1717343754119,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "E69gBMG46WVF",
    "outputId": "7b1ecdbb-6c5d-4152-d10a-860d7cb587d9"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pearson_cosine': 0.7222320710908391,\n",
       " 'spearman_cosine': 0.725059765496038,\n",
       " 'pearson_manhattan': 0.7338172618636865,\n",
       " 'spearman_manhattan': 0.7323465534428775,\n",
       " 'pearson_euclidean': 0.7332726423686017,\n",
       " 'spearman_euclidean': 0.7316943270141215,\n",
       " 'pearson_dot': 0.6603672299249149,\n",
       " 'spearman_dot': 0.6624301208511642,\n",
       " 'pearson_max': 0.7338172618636865,\n",
       " 'spearman_max': 0.7323465534428775}"
      ]
     },
     "execution_count": 16,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "wnGpHqy46YS3"
   },
   "source": [
    "⚠️ **VRAM Clean-up**\n",
    "* `Restart` the notebook to clean up GPU memory before moving on to the next training example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Xvps7UpznPD4"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "3yh0toLx2Ni7"
   },
   "source": [
    "## Multiple Negatives Ranking Loss"
   ]
  },
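  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Multiple Negatives Ranking Loss treats every other positive in a batch as a negative for a given anchor. The sketch below illustrates this idea in plain PyTorch. It is a simplified version that assumes the library's defaults as we understand them: cosine similarity multiplied by a scale of 20, followed by a cross-entropy loss over all candidates. Random tensors stand in for sentence embeddings, and the helper `mnr_loss` is hypothetical. Explicit negatives, such as the shuffled soft negatives used in this section, are appended to the pool of candidates."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def mnr_loss(anchors, positives, negatives=None, scale=20.0):\n",
    "    # Hypothetical helper. Candidates for every anchor: all positives in the batch plus optional negatives\n",
    "    candidates = positives if negatives is None else torch.cat([positives, negatives], dim=0)\n",
    "    # Scaled cosine similarity matrix of shape (batch_size, num_candidates)\n",
    "    scores = F.normalize(anchors, dim=-1) @ F.normalize(candidates, dim=-1).T * scale\n",
    "    # Anchor i should pick candidate i (its own positive); every other candidate acts as a negative\n",
    "    labels = torch.arange(anchors.size(0))\n",
    "    return F.cross_entropy(scores, labels)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "anchors = torch.randn(4, 8)\n",
    "positives = anchors + 0.01 * torch.randn(4, 8)  # near-copies of the anchors: easy positives\n",
    "negatives = torch.randn(4, 8)  # unrelated vectors: soft negatives\n",
    "\n",
    "aligned_loss = mnr_loss(anchors, positives, negatives)\n",
    "# Misaligning the positives should make the loss much larger\n",
    "shuffled_loss = mnr_loss(anchors, positives.roll(1, dims=0), negatives)\n",
    "print(aligned_loss.item(), shuffled_loss.item())"
   ]
  },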
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3147,
     "status": "ok",
     "timestamp": 1717343758130,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "xzToWFH0vZzz",
    "outputId": "d3dc969a-8fbd-4f5e-9ff0-b85ce9851312"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "16875it [00:01, 14110.96it/s]\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "16875"
      ]
     },
     "execution_count": 18,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "import random\n",
    "from tqdm import tqdm\n",
    "from datasets import Dataset, load_dataset\n",
    "\n",
    "# Load MNLI dataset from GLUE\n",
    "mnli = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n",
    "mnli = mnli.remove_columns(\"idx\")\n",
    "# Keep only entailment pairs (label 0) as anchor/positive pairs\n",
    "mnli = mnli.filter(lambda x: x['label'] == 0)\n",
    "\n",
    "# Prepare data and add a soft negative\n",
    "train_dataset = {\"anchor\": [], \"positive\": [], \"negative\": []}\n",
    "# Shuffle the hypotheses so each anchor is paired with an unrelated sentence\n",
    "soft_negatives = list(mnli[\"hypothesis\"])\n",
    "random.shuffle(soft_negatives)\n",
    "for row, soft_negative in tqdm(zip(mnli, soft_negatives)):\n",
    "    train_dataset[\"anchor\"].append(row[\"premise\"])\n",
    "    train_dataset[\"positive\"].append(row[\"hypothesis\"])\n",
    "    train_dataset[\"negative\"].append(soft_negative)\n",
    "train_dataset = Dataset.from_dict(train_dataset)\n",
    "len(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wP_s1yAB7D7I"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Create an embedding similarity evaluator for stsb\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],  # rescale 0-5 gold scores to 0-1\n",
    "    main_similarity=\"cosine\"\n",
    ")"
   ]
  },
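  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The evaluator reports metrics such as `pearson_cosine` and `spearman_cosine`. As we understand it, these metrics correlate the cosine similarity of each embedded sentence pair with its rescaled gold score. The simplified sketch below reproduces that idea with random tensors in place of real embeddings. The helper names `pearson` and `spearman` are hypothetical, and tied values are not handled when ranking."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn.functional as F\n",
    "\n",
    "def pearson(x, y):\n",
    "    # Pearson correlation between two 1-D tensors\n",
    "    x, y = x - x.mean(), y - y.mean()\n",
    "    return (x * y).sum() / (x.norm() * y.norm())\n",
    "\n",
    "def ranks(t):\n",
    "    # argsort of argsort yields 0-based ranks (assumes no ties)\n",
    "    return t.argsort().argsort().float()\n",
    "\n",
    "def spearman(x, y):\n",
    "    # Spearman correlation is the Pearson correlation of the ranks\n",
    "    return pearson(ranks(x), ranks(y))\n",
    "\n",
    "# Hypothetical toy data: 5 sentence pairs with gold scores already scaled to [0, 1]\n",
    "torch.manual_seed(0)\n",
    "emb1, emb2 = torch.randn(5, 16), torch.randn(5, 16)\n",
    "gold = torch.tensor([0.9, 0.1, 0.5, 0.7, 0.3])\n",
    "\n",
    "cosine_scores = F.cosine_similarity(emb1, emb2, dim=1)\n",
    "print('pearson_cosine:', pearson(cosine_scores, gold).item())\n",
    "print('spearman_cosine:', spearman(cosine_scores, gold).item())"
   ]
  },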
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 379,
     "referenced_widgets": [
      "5913c2edc4de4e48990626d19af0ff2b",
      "d39dba3810c94462bee671fe31f85691",
      "e9bd9646a1e14113b4effd2f594f73c4",
      "ae508d9b92854976ba0cf914c8bf2ad0",
      "59e3baf2f35a40998051f4ec0d0e5f1b",
      "c0a5f6b6e620466dadc2806a3f39ae24",
      "08834e2157ae46e6b346cccb936fbf81",
      "f4ee4ed0429949d7b8ccc58310544785",
      "1d8dd9db554449c7980ed62cc0f33461",
      "e84dc9109f1c4ccd81e40d246f772b57",
      "ffb89bdce63844a09fe31cb1178f1e5b"
     ]
    },
    "executionInfo": {
     "elapsed": 182305,
     "status": "ok",
     "timestamp": 1717343942563,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "j-Q2m0yzkRvW",
    "outputId": "7f83f28d-79d1-4b1e-d70a-85f2cacee1c1"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:sentence_transformers.SentenceTransformer:No sentence-transformers model found with name bert-base-uncased. Creating a new one with mean pooling.\n",
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='528' max='528' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [528/528 03:00, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.345200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.105500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.079000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.062200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.069000</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5913c2edc4de4e48990626d19af0ff2b",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=528, training_loss=0.12795479415041028, metrics={'train_runtime': 180.866, 'train_samples_per_second': 93.301, 'train_steps_per_second': 2.919, 'total_flos': 0.0, 'train_loss': 0.12795479415041028, 'epoch': 1.0})"
      ]
     },
     "execution_count": 20,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import losses, SentenceTransformer\n",
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define model\n",
    "embedding_model = SentenceTransformer('bert-base-uncased')\n",
    "\n",
    "# Loss function\n",
    "train_loss = losses.MultipleNegativesRankingLoss(model=embedding_model)\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"mnrloss_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    eval_steps=100,\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 3617,
     "status": "ok",
     "timestamp": 1717343946178,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "YvPEvgf98uS8",
    "outputId": "28a0b598-4553-4da7-ae5f-159698dd3255"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pearson_cosine': 0.8070727434643791,\n",
       " 'spearman_cosine': 0.8106193672462586,\n",
       " 'pearson_manhattan': 0.8213132116968124,\n",
       " 'spearman_manhattan': 0.8164551132664518,\n",
       " 'pearson_euclidean': 0.820988086354926,\n",
       " 'spearman_euclidean': 0.8160139830687847,\n",
       " 'pearson_dot': 0.7429357515240518,\n",
       " 'spearman_dot': 0.7316164586329814,\n",
       " 'pearson_max': 0.8213132116968124,\n",
       " 'spearman_max': 0.8164551132664518}"
      ]
     },
     "execution_count": 21,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ND1ej1ag054E"
   },
   "source": [
    "# **Fine-tuning**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "tqATI-1V7coM"
   },
   "source": [
    "⚠️ **VRAM Clean-up**\n",
    "* `Restart` the notebook to clean up GPU memory before moving on to the next training example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "OkWj0SfYnRFd"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZGxyXucEkjfw"
   },
   "source": [
    "## **Supervised**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5GXBTm_C-IPE"
   },
   "outputs": [],
   "source": [
    "from datasets import load_dataset\n",
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Load MNLI dataset from GLUE\n",
    "# 0 = entailment, 1 = neutral, 2 = contradiction\n",
    "train_dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(50_000))\n",
    "train_dataset = train_dataset.remove_columns(\"idx\")\n",
    "\n",
    "# Create an embedding similarity evaluator for stsb\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],  # rescale 0-5 gold scores to 0-1\n",
    "    main_similarity=\"cosine\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 1000,
     "referenced_widgets": [
      "a3a94fe38f134761a6998fa7a057eb7e",
      "b6813e95e0ef44ceb3a2e5547d5d3638",
      "cc49f6d66133462f99eacb2158131a4d",
      "7428954237c64ea9aa024f00b055bca3",
      "bef91caa09414e4a9dab354676994fc7",
      "d7ec2f12b4ae47979c429da424dfdce6",
      "e34d030373e746df84576c0c6d956fda",
      "bbe74e8f1c3d412895a5bfaff4e26d15",
      "4d44b380d5034a66b09bde71ed2b803a",
      "fb5eae27c64c421d886440024aa343d9",
      "fae4e813ca714f07b82fff44071bd464",
      "924aec3545864e3fb5fdcf962c21836a",
      "c314d05c3d7146d49edddd93aed4a5a1",
      "3b434bb3087e4ab299f933844f1f42ce",
      "32138e1691874ef58b723c859ee3416b",
      "00000a1e212149afbddc87276d9a9914",
      "77b9617224474504a03a1f2ecf27f038",
      "ed30ff2b08844714b668c0625dea7e98",
      "bd96c20d553a4c16bff8a2faa59caca0",
      "31ad69ee3fa74cc9b4ec22e94333af38",
      "08ef3e745846458aaf734055f29f8504",
      "5a41ab1f2cae468180064f16241c5d49",
      "cde32db4b98f41c9948c2f28c20a4c0a",
      "6f4d13899d314132866a48f900be8323",
      "9c35401c3dce44b485c1b411f4764df7",
      "3a95f39a58354022a26a218520b65564",
      "36e69b2801ef4e388db0a51c4254b110",
      "e6bef80c762f4d928352b03ce7f2cd0d",
      "ae49c4b942ef4487b9988de166d3b842",
      "3c1eef14cc244fa59532b7f145e4a4de",
      "0aeb21e9873c4f1da6c97a9e019fe466",
      "cb13d89f31cb472d94dc676012d48cb5",
      "a9b14052361a467db942709fa73d76f7",
      "bbbb12d403fc4edc9f27e3ef9927ffd8",
      "09779a3f2b5d4417a1f8be9c91793d39",
      "06cb51615ebf477291a045e258c84ff8",
      "f43b70e6ae6f490bb1dec4c7609deb6f",
      "aeb5d637d2944edb9bde0f9de72faf8b",
      "38ca2c76a6a54b12aa1548478c7c8f9e",
      "967cf768733440edbafa586b34d05f1f",
      "4e91807ba3e14ee494ec3b827eb904e8",
      "9447d576fdb1417ebed3b287840301d5",
      "ab60b7e545624816bb705a4e54d3d5d3",
      "8f8f536ed92b401f92fa8a3368f1e21a",
      "f8e627842b8c4342ac797c94121e83bf",
      "f98797f8ebb141bbaf0a1e45d1f34ee9",
      "668acd6e586d4797aeab2f60762a2f3e",
      "19d2d601f02d46929de86ec7afc97c91",
      "206e8c5194e041ac8040421a7f3dc3d7",
      "9cfa6786bd4f4013a21c98db571417ae",
      "b857ac58c2434efc8539a0fe4b16ae51",
      "5167f315d5ff4e93a4f91013b7f715ab",
      "08a6cfb598a04f85a860f769f1b74749",
      "304a1041ef364bd3b04478226f096f39",
      "3fa4355c43474c31b31e6f384ddcb367",
      "7a5a63ac37a94a009b7f6ba99f08e2da",
      "70fd7dba36d340bdb32af8f45eda9e8d",
      "c88a2375496640b7ad4dfc98ef220845",
      "493825b991b24299a76542031fb3b93f",
      "c9c90ac915e74c54a5d89791774492f8",
      "2e57db9adf69495fbae421c21050d1d9",
      "fe51a5f8862142799420fa5156987df7",
      "2bb2f11b9c564433aa13fb2957a0aa1e",
      "bc293034952c43a0b93e4e47719610ff",
      "79d478ec7beb43c7996d3c5437608ab7",
      "698787c87bd645459a6b3174cbf11357",
      "49d930971e0e4ef3a611eec49bf0f53e",
      "5ccd802fc480477db25a63ea3ca99c53",
      "c11c670643a04e6dbc032fe017cdd00c",
      "13e4cd853e6048d4b5368fa9344f727d",
      "38b364436ece48c69ef0ed2a11d049ba",
      "64e449d1577e484d96524a8e659d9844",
      "0260a96090cc4e3ba94ea315a9367709",
      "84afd14b8ef64a9fbaf9e0d39013986d",
      "4dde6a52b4404b658fb2503efd30ef51",
      "0648485e203d447b8a1712ec4b02ba74",
      "928648503f2d4571ad0c05b101a245d9",
      "6413d644285c47d08e2991171f2fedb6",
      "7505d8bd260b4e6387cb36d797817e8c",
      "01426c6fdf49417fb41936535799f4e8",
      "538a8524990242469f80eebf13606de0",
      "1ee0ef2217c84512848c5df4fd3104d1",
      "c846b94b350542d3b75a0e584ab3f405",
      "3361e75e6c8742a49545f610e0edd75c",
      "b057909b311a41528cd38477de6f45d3",
      "eb2b29b42f164a48baec6cf8f2e8d0e5",
      "c33b96eb81234f38821a3cc20a36d526",
      "c02b97d22c6246cead5eb05764dd5eeb",
      "aaf53993e3f1420b920b0afaf2919842",
      "7b6d904fe23748368f0952d4f434344c",
      "baa4d8f579f1484c9a8071a142b2f864",
      "ab3c279b3c144767ba629d50b6015a58",
      "2651661da9c64b4b95736fd574c7363d",
      "cb87e1b2613040148e61b06bce236cdd",
      "041003a4c86649c7a3943b55a57a681b",
      "1c9de49116b446d3a338b6b84ddb9eba",
      "bf4f4e0d9bd046ea94ad9fbb48af6df4",
      "0c36ef857d80437db1265cc1d8537a2f",
      "09975b41a4b448e1b084b2d5b78466ab",
      "22bad8429f9d4e4db13f79b7c36c79cb",
      "c6d04928e48b43cca98b7744898b51b5",
      "700a386e962c495c866975ac579194b6",
      "cdeed5d6108a4cae842022ed365e6906",
      "ac2888c58bac46f2ae0d00ba0c1f30d8",
      "5cf8d4b4109946139e2fc796e8aa7318",
      "b4d08071a3254c509e4a6d40f73a04e5",
      "59a254c7826f4ec4aeafb9792c2f42dd",
      "db97efdd7c60447e85491bf14dabb005",
      "f62fc29c15ac4310a0fc3b005032a7fc",
      "fc75e9fc84e44b4694b883e0c1dc350c",
      "3d885279ea8f4e1dacca331fbc0a9d0a",
      "46f92ff07bca4be3b5e208a0f0bf5b9f",
      "5ede26183db14c419c8dd21b6fb97b9b",
      "de6a0dfdfc574011ad79e47c730c8f3d",
      "9ab322b5529c4e6ca8c383f754d0da61",
      "829e885ee8684126a5feb131387c2bb4",
      "fbb52f8ababd49658ad696bc1a3306b6",
      "4ae67119ec9d4613bc418b8dd1878a1f",
      "308403acd75a43ee8271ab8b01f42a06",
      "e56e48a2c220471bb45b6c7720ac4adf",
      "ce36ee95348e4ffcbd0b0705398f104e",
      "511480884e6145cda98eac03e3eaa214",
      "b4da64c10a7245ec8e406bc60ced8377",
      "4f8723edce1046bc973f74fda2de11e1",
      "5f4fbb82a6d84b3898fee35a31a3fba9",
      "b4c0317ec5554c76906a7b2d942b9269",
      "d97c721beaf0476cbeee953fc2b31b3e",
      "a95563bc7ba94467b7e653b93f1d4c19",
      "2ebb224bce544c0bba1ad3324384da7b",
      "391709a514ca45138d7ce8c94c747b5d",
      "e191c02489c947c593c9321526d27b43",
      "e5fa8fe92e6e42faa17c9bebcc17c748",
      "dc3cccdae676461b991e6b375b482344",
      "5f9f805d69cb4aa3a8bca67088a1e5ae",
      "fd57703e5fe94245a9dc596cd5555211",
      "ce584d4254224a5ba66d7fd39e8ef387",
      "210d9c45015e433f9d3b79505f6db592",
      "50a0c0604547472cbc48657d2ccf3107",
      "f967688c64764354923d2af38cfb61ce",
      "4bd9c7e8e97144d4a8076adaeffa35e9",
      "f817a673080648bcb05d73d1bceac28d",
      "6741232d8ee94bc4b5afe4c4ecc56c5d",
      "fc2d101e2aff401ab6a5238dc0e6debc",
      "27bf3b15061241b5bc14c47e9713364e",
      "3de506a04f3d45e3a5e18872df01d2be",
      "719a1bde61c9482d8aa8da6c3c730034",
      "e524e0d7a70c4c42a9c0206aab89cd1a",
      "617fdebf725d45c48e706db1c6606280",
      "ae71a033b1304aa286a107cbcdede0a0",
      "2930ba418d214a4097d17daabd087714",
      "d45daa6f3a094f44956d181486da4b83",
      "bd2971fb392e41f5b8c56c4a6eba0150",
      "e52b7c3a666b4d67af39747a8d397a7a",
      "f99007a7ea614b5e8e9ae5ceb64b1da1"
     ]
    },
    "executionInfo": {
     "elapsed": 122810,
     "status": "ok",
     "timestamp": 1717344073501,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "iQ3vW2VTA9dN",
    "outputId": "8772abf8-992c-4ec1-9e27-2f90cf1e4503"
   },
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a3a94fe38f134761a6998fa7a057eb7e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "924aec3545864e3fb5fdcf962c21836a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config_sentence_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "cde32db4b98f41c9948c2f28c20a4c0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "README.md:   0%|          | 0.00/10.7k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "bbbb12d403fc4edc9f27e3ef9927ffd8",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "sentence_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "f8e627842b8c4342ac797c94121e83bf",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "config.json:   0%|          | 0.00/612 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "7a5a63ac37a94a009b7f6ba99f08e2da",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "model.safetensors:   0%|          | 0.00/90.9M [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "49d930971e0e4ef3a611eec49bf0f53e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer_config.json:   0%|          | 0.00/350 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6413d644285c47d08e2991171f2fedb6",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "aaf53993e3f1420b920b0afaf2919842",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "22bad8429f9d4e4db13f79b7c36c79cb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "special_tokens_map.json:   0%|          | 0.00/112 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "3d885279ea8f4e1dacca331fbc0a9d0a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "1_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/html": [
       "\n",
       "    <div>\n",
       "      \n",
       "      <progress value='1563' max='1563' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
       "      [1563/1563 01:57, Epoch 1/1]\n",
       "    </div>\n",
       "    <table border=\"1\" class=\"dataframe\">\n",
       "  <thead>\n",
       " <tr style=\"text-align: left;\">\n",
       "      <th>Step</th>\n",
       "      <th>Training Loss</th>\n",
       "    </tr>\n",
       "  </thead>\n",
       "  <tbody>\n",
       "    <tr>\n",
       "      <td>100</td>\n",
       "      <td>0.155500</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>200</td>\n",
       "      <td>0.110000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>300</td>\n",
       "      <td>0.118600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>400</td>\n",
       "      <td>0.115300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>500</td>\n",
       "      <td>0.110700</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>600</td>\n",
       "      <td>0.101000</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>700</td>\n",
       "      <td>0.113100</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>800</td>\n",
       "      <td>0.099800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>900</td>\n",
       "      <td>0.109600</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1000</td>\n",
       "      <td>0.105800</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1100</td>\n",
       "      <td>0.094900</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1200</td>\n",
       "      <td>0.106400</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1300</td>\n",
       "      <td>0.105300</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1400</td>\n",
       "      <td>0.105200</td>\n",
       "    </tr>\n",
       "    <tr>\n",
       "      <td>1500</td>\n",
       "      <td>0.106600</td>\n",
       "    </tr>\n",
       "  </tbody>\n",
       "</table><p>"
      ],
      "text/plain": [
       "<IPython.core.display.HTML object>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "511480884e6145cda98eac03e3eaa214",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "dc3cccdae676461b991e6b375b482344",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "27bf3b15061241b5bc14c47e9713364e",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Computing widget examples:   0%|          | 0/5 [00:00<?, ?example/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "text/plain": [
       "TrainOutput(global_step=1563, training_loss=0.10982195932897176, metrics={'train_runtime': 117.3739, 'train_samples_per_second': 425.989, 'train_steps_per_second': 13.316, 'total_flos': 0.0, 'train_loss': 0.10982195932897176, 'epoch': 1.0})"
      ]
     },
     "execution_count": 24,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from sentence_transformers import losses, SentenceTransformer\n",
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define model\n",
    "embedding_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n",
    "\n",
    "# Loss function\n",
    "train_loss = losses.MultipleNegativesRankingLoss(model=embedding_model)\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"finetuned_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    eval_steps=100,\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1856,
     "status": "ok",
     "timestamp": 1717344075355,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "MaPJIpkS-ZrT",
    "outputId": "58bd4fed-7c3e-428e-d1a9-da74b7e97998"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "{'pearson_cosine': 0.8489503881223601,\n",
       " 'spearman_cosine': 0.8484667083117318,\n",
       " 'pearson_manhattan': 0.8503843871673679,\n",
       " 'spearman_manhattan': 0.8475679105384369,\n",
       " 'pearson_euclidean': 0.8513072191805562,\n",
       " 'spearman_euclidean': 0.8484667083117318,\n",
       " 'pearson_dot': 0.8489503890256918,\n",
       " 'spearman_dot': 0.8484667083117318,\n",
       " 'pearson_max': 0.8513072191805562,\n",
       " 'spearman_max': 0.8484667083117318}"
      ]
     },
     "execution_count": 25,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 2630,
     "status": "ok",
     "timestamp": 1717344077983,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "3pHpVCwmk-XW",
    "outputId": "22b999e9-3b78-41ba-f96f-a88add04a6c6"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "{'pearson_cosine': 0.8696194608752055,\n",
       " 'spearman_cosine': 0.8671637433378804,\n",
       " 'pearson_manhattan': 0.8670399009851635,\n",
       " 'spearman_manhattan': 0.8663946139224048,\n",
       " 'pearson_euclidean': 0.867871599362501,\n",
       " 'spearman_euclidean': 0.8671643653432983,\n",
       " 'pearson_dot': 0.8696194616795601,\n",
       " 'spearman_dot': 0.8671631197908374,\n",
       " 'pearson_max': 0.8696194616795601,\n",
       " 'spearman_max': 0.8671643653432983}"
      ]
     },
     "execution_count": 26,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Evaluate the pre-trained model\n",
    "original_model = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')\n",
    "evaluator(original_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ii6sIMpH7d7S"
   },
   "source": [
    "⚠️ **VRAM Clean-up**\n",
    "* `Restart` the notebook to clean up GPU memory before moving on to the next training example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "YGCTfC-unSL1"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "nvCPXCSZkkxm"
   },
   "source": [
    "## **Augmented SBERT**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CtoEArJElrZh"
   },
   "source": [
    "**Step 1:** Fine-tune a cross-encoder"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "executionInfo": {
     "elapsed": 1908,
     "status": "ok",
     "timestamp": 1717344080661,
     "user": {
      "displayName": "Maarten Grootendorst",
      "userId": "11015108362723620659"
     },
     "user_tz": -120
    },
    "id": "IJhEDAeeyhPD",
    "outputId": "fedfd47d-da9a-46b0-a93b-936b3c2eb025"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|██████████| 10000/10000 [00:00<00:00, 25870.92it/s]\n"
     ]
    }
   ],
   "source": [
    "import pandas as pd\n",
    "from tqdm import tqdm\n",
    "from datasets import load_dataset, Dataset\n",
    "from sentence_transformers import InputExample\n",
    "from sentence_transformers.datasets import NoDuplicatesDataLoader\n",
    "\n",
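    "# Prepare a small gold set of 10,000 labeled sentence pairs for the cross-encoder\n",
    "dataset = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(10_000))\n",
    "# Map MNLI labels to a binary target: entailment (0) -> 1, neutral (1) and contradiction (2) -> 0\n",
    "mapping = {2: 0, 1: 0, 0: 1}\n",
    "\n",
    "# Wrap each pair in an InputExample and batch them for the cross-encoder\n",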
    "gold_examples = [\n",
    "    InputExample(texts=[row[\"premise\"], row[\"hypothesis\"]], label=mapping[row[\"label\"]])\n",
    "    for row in tqdm(dataset)\n",
    "]\n",
    "gold_dataloader = NoDuplicatesDataLoader(gold_examples, batch_size=32)\n",
    "\n",
    "# Pandas DataFrame for easier data handling\n",
    "gold = pd.DataFrame(\n",
    "    {\n",
    "    'sentence1': dataset['premise'],\n",
    "    'sentence2': dataset['hypothesis'],\n",
    "    'label': [mapping[label] for label in dataset['label']]\n",
    "    }\n",
    ")"
   ]
  },
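  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The `mapping` dictionary above collapses MNLI's three classes into a binary target. In GLUE's MNLI, label 0 is entailment, 1 is neutral, and 2 is contradiction, so only entailing pairs become positives. A minimal sketch on a few hypothetical labels (`mnli_labels` is made up, not taken from the dataset):\n",
    "\n",
    "```python\n",
    "from collections import Counter\n",
    "\n",
    "# GLUE MNLI convention: 0 = entailment, 1 = neutral, 2 = contradiction\n",
    "mapping = {2: 0, 1: 0, 0: 1}\n",
    "\n",
    "# Hypothetical labels standing in for dataset['label']\n",
    "mnli_labels = [0, 1, 2, 0, 2, 1, 0]\n",
    "binary_labels = [mapping[label] for label in mnli_labels]\n",
    "\n",
    "print(binary_labels)           # [1, 0, 0, 1, 0, 0, 1]\n",
    "print(Counter(binary_labels))  # Counter({0: 4, 1: 3})\n",
    "```\n",
    "\n",
    "Since neutral and contradiction both map to 0, expect roughly twice as many negatives as positives when the original three classes are about equally frequent."
   ]
  },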
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 174,
     "referenced_widgets": [
      "10fa8eb5de404747b62ecc8cd59e25fd",
      "f6bf22b46ab443728a9dc23543023875",
      "53d71e19181e41b3bb35a4b3f5024ebc",
      "a63862df0d904f3f9cca19e4fff9e81a",
      "aa924947fbfe44f8a5abdbb2f3773772",
      "af6b364698eb4974b15f9d1b1c2ee2b3",
      "1ce96fcbf1774d28984d3b5b1c64b1e3",
      "29ff06629da54435ba40c04a69664fc7",
      "e0b9c66de56f454c9441a0a12614640a",
      "651526114ca74747b73067b75a1765e1",
      "51cf65441cae4550af4cc125eb438702",
      "c8ca3c14b870402b8931635fa06ed32a",
      "350c7dd7e31c4df4bc0ef962dea83e9d",
      "393cd52e1286470f94dbba825945c8f6",
      "bf5ef57c25e645f1b6cc6c24bc94818c",
      "985561e2b3414808a65d357df5aefcc2",
      "df5b22c091204b33bc4a6af2e7fe4ec8",
      "9cbd98c3d46842cebe4a13d25108eb6b",
      "ee7fd722284b43f3a85da609c28f1d6d",
      "439896b44d0244d28f5669e5e6ff065f",
      "8fe2472211c74d2dad8c017122d185de",
      "b165e17da35c45e0832e31fba6e4e20f"
     ]
    },
    "id": "-_MHAJzl2H6Z",
    "outputId": "74ff2a68-17a6-47a4-cd48-6fd502edbcb6"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.\n",
      "  warnings.warn(\n",
      "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
      "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
     ]
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "10fa8eb5de404747b62ecc8cd59e25fd",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Epoch:   0%|          | 0/1 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "c8ca3c14b870402b8931635fa06ed32a",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Iteration:   0%|          | 0/312 [00:00<?, ?it/s]"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "from sentence_transformers.cross_encoder import CrossEncoder\n",
    "\n",
    "# Train a cross-encoder on the gold dataset\n",
    "cross_encoder = CrossEncoder('bert-base-uncased', num_labels=2)\n",
    "cross_encoder.fit(\n",
    "    train_dataloader=gold_dataloader,\n",
    "    epochs=1,\n",
    "    show_progress_bar=True,\n",
    "    warmup_steps=100,\n",
    "    use_amp=False\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "X0OcVG6WmMOJ"
   },
   "source": [
    "**Step 2:** Create new sentence pairs (here, 40,000 additional MNLI pairs whose original labels we ignore)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fgx8N8a8kVrZ"
   },
   "outputs": [],
   "source": [
    "# Take 40,000 new, unlabeled sentence pairs (the original MNLI labels are not used)\n",
    "silver = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(10_000, 50_000))\n",
    "pairs = list(zip(silver['premise'], silver['hypothesis']))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qcG7cDG5qrwX"
   },
   "source": [
    "**Step 3:** Label new sentence pairs with the fine-tuned cross-encoder (silver dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "O9Yuhzxq2NMj"
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "# Label the sentence pairs using our fine-tuned cross-encoder\n",
    "output = cross_encoder.predict(pairs, apply_softmax=True, show_progress_bar=True)\n",
    "silver = pd.DataFrame(\n",
    "    {\n",
    "        \"sentence1\": silver[\"premise\"],\n",
    "        \"sentence2\": silver[\"hypothesis\"],\n",
    "        \"label\": np.argmax(output, axis=1)\n",
    "    }\n",
    ")"
   ]
  },
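  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "With `apply_softmax=True`, `predict` returns one probability per class for each pair, and `np.argmax(..., axis=1)` keeps the most likely class as the silver label. The sketch below replays that step on hypothetical logits (not real cross-encoder outputs) with a hand-written softmax:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def softmax(logits):\n",
    "    # Subtract the row-wise max for numerical stability\n",
    "    shifted = logits - logits.max(axis=1, keepdims=True)\n",
    "    exp = np.exp(shifted)\n",
    "    return exp / exp.sum(axis=1, keepdims=True)\n",
    "\n",
    "# Hypothetical 2-class logits for three sentence pairs\n",
    "logits = np.array([[2.0, -1.0], [0.1, 0.3], [-0.5, 1.5]])\n",
    "probs = softmax(logits)\n",
    "silver_labels = np.argmax(probs, axis=1)\n",
    "\n",
    "print(probs.round(3))\n",
    "print(silver_labels)  # [0 1 1]\n",
    "```\n",
    "\n",
    "Taking the argmax discards the cross-encoder's confidence. Keeping `probs[:, 1]` as a soft label in [0, 1] would be an alternative that `CosineSimilarityLoss` can also consume, but it is not what this notebook does."
   ]
  },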
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "D9Jd-Kssqzk_"
   },
   "source": [
    "**Step 4:** Train a bi-encoder (SBERT) on the extended dataset (gold + silver dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fp09qxMhzagi"
   },
   "outputs": [],
   "source": [
    "# Combine gold + silver\n",
    "data = pd.concat([gold, silver], ignore_index=True, axis=0)\n",
    "data = data.drop_duplicates(subset=['sentence1', 'sentence2'], keep=\"first\")\n",
    "train_dataset = Dataset.from_pandas(data, preserve_index=False)"
   ]
  },
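  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Because `gold` is concatenated before `silver`, `drop_duplicates(..., keep=\"first\")` keeps the gold (human) label whenever the same pair appears in both sets. A toy illustration with made-up sentences and labels:\n",
    "\n",
    "```python\n",
    "import pandas as pd\n",
    "\n",
    "# Toy gold and silver frames that share one pair with conflicting labels\n",
    "gold = pd.DataFrame({'sentence1': ['A cat sits.', 'It rains.'],\n",
    "                     'sentence2': ['An animal sits.', 'It is dry.'],\n",
    "                     'label': [1, 0]})\n",
    "silver = pd.DataFrame({'sentence1': ['A cat sits.', 'He runs.'],\n",
    "                       'sentence2': ['An animal sits.', 'He moves.'],\n",
    "                       'label': [0, 1]})\n",
    "\n",
    "data = pd.concat([gold, silver], ignore_index=True)\n",
    "data = data.drop_duplicates(subset=['sentence1', 'sentence2'], keep='first')\n",
    "\n",
    "print(len(data))  # 3\n",
    "print(data.loc[data['sentence1'] == 'A cat sits.', 'label'].tolist())  # [1]\n",
    "```"
   ]
  },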
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "S-6RW_wOAOwO"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Create an embedding similarity evaluator for stsb\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],\n",
    "    main_similarity=\"cosine\"\n",
    ")"
   ]
  },
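  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`EmbeddingSimilarityEvaluator` scores each STS-B pair with the model and reports how well those similarities correlate with the gold scores (divided by 5 so they fall in [0, 1]), for example via Spearman rank correlation. Because Spearman correlation only depends on rankings, that rescaling does not change it. The NumPy sketch below computes Spearman correlation by hand on made-up numbers; it assumes there are no tied values and is not the library's own implementation:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def rank(values):\n",
    "    # Rank positions 0..n-1 (assumes no ties)\n",
    "    order = np.argsort(values)\n",
    "    ranks = np.empty(len(values))\n",
    "    ranks[order] = np.arange(len(values))\n",
    "    return ranks\n",
    "\n",
    "def spearman(x, y):\n",
    "    # Spearman correlation = Pearson correlation of the ranks\n",
    "    return np.corrcoef(rank(x), rank(y))[0, 1]\n",
    "\n",
    "# Made-up cosine similarities and STS-B style gold scores (0-5)\n",
    "cosine_scores = np.array([0.91, 0.15, 0.55, 0.72, 0.30])\n",
    "gold_scores = np.array([4.8, 0.4, 2.5, 3.9, 1.6])\n",
    "\n",
    "print(round(float(spearman(cosine_scores, gold_scores)), 6))      # 1.0\n",
    "print(round(float(spearman(cosine_scores, gold_scores / 5)), 6))  # 1.0\n",
    "```"
   ]
  },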
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MK1KybOI_uIY"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers import losses, SentenceTransformer\n",
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define model\n",
    "embedding_model = SentenceTransformer('bert-base-uncased')\n",
    "\n",
    "# Loss function\n",
    "train_loss = losses.CosineSimilarityLoss(model=embedding_model)\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"augmented_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    eval_steps=100,  # only takes effect together with eval_strategy='steps'\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
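  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`CosineSimilarityLoss` embeds both sentences of a pair, takes their cosine similarity, and regresses it onto the label with a mean squared error. With the binary labels used here, entailing pairs are pushed toward a similarity of 1 and all other pairs toward 0. A simplified NumPy sketch of that objective on two hand-made pairs (no model and no gradients; `cosine_similarity_loss` is a hypothetical helper):\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "def cosine_similarity_loss(emb1, emb2, labels):\n",
    "    # Row-wise cosine similarity between paired embeddings\n",
    "    cos = np.sum(emb1 * emb2, axis=1) / (\n",
    "        np.linalg.norm(emb1, axis=1) * np.linalg.norm(emb2, axis=1)\n",
    "    )\n",
    "    # Mean squared error between similarity and target label\n",
    "    return np.mean((cos - labels) ** 2)\n",
    "\n",
    "emb1 = np.array([[1.0, 0.0], [1.0, 0.0]])\n",
    "emb2 = np.array([[1.0, 0.0], [0.0, 1.0]])  # identical pair, orthogonal pair\n",
    "labels = np.array([1.0, 0.0])\n",
    "\n",
    "print(cosine_similarity_loss(emb1, emb2, labels))      # 0.0\n",
    "print(cosine_similarity_loss(emb1, emb2, 1 - labels))  # 1.0\n",
    "```"
   ]
  },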
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9_NHjK75z58G"
   },
   "outputs": [],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "fwAaFBHvDcFi"
   },
   "outputs": [],
   "source": [
    "trainer.accelerator.clear()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "CX6lArIH0h1A"
   },
   "source": [
    "**Step 5:** Train and evaluate a model on the gold dataset only (without the silver dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "wyPBGfxp0D_7"
   },
   "outputs": [],
   "source": [
    "# Use only the gold dataset this time\n",
    "data = gold.drop_duplicates(subset=['sentence1', 'sentence2'], keep=\"first\")\n",
    "train_dataset = Dataset.from_pandas(data, preserve_index=False)\n",
    "\n",
    "# Define model\n",
    "embedding_model = SentenceTransformer('bert-base-uncased')\n",
    "\n",
    "# Loss function\n",
    "train_loss = losses.CosineSimilarityLoss(model=embedding_model)\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"gold_only_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=32,\n",
    "    per_device_eval_batch_size=32,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    eval_steps=100,  # only takes effect together with eval_strategy='steps'\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6L8_5TLJ0jdK"
   },
   "outputs": [],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "9ZujSUsu0sHU"
   },
   "source": [
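    "Compared to training on both the gold and silver datasets, training on the gold dataset alone lowers the model's performance on STS-B, which suggests that the silver pairs labeled by the cross-encoder add useful training signal."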
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "qVq3FSZL7gK7"
   },
   "source": [
    "⚠️ **VRAM Clean-up**\n",
    "* `Restart` the notebook to clean up memory before moving on to the next training example."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "6gckDRJ1nUfo"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "p7RNAKVl3wmM"
   },
   "source": [
    "## **Unsupervised Learning**"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Oq_phjTb31gX"
   },
   "source": [
    "### Transformer-based Sequential Denoising Auto-Encoder (TSDAE)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "8yMdUf_WwErS"
   },
   "outputs": [],
   "source": [
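    "# Download NLTK's punkt tokenizer, used by the default noise function to split sentences into words\n",
    "import nltk\n",
    "nltk.download('punkt')\n",
    "# Newer NLTK releases look up 'punkt_tab' instead\n",
    "nltk.download('punkt_tab')"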
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ruI-lOZYZt7J"
   },
   "outputs": [],
   "source": [
    "from tqdm import tqdm\n",
    "from datasets import Dataset, load_dataset\n",
    "from sentence_transformers.datasets import DenoisingAutoEncoderDataset\n",
    "\n",
    "# Create a flat list of sentences\n",
    "mnli = load_dataset(\"glue\", \"mnli\", split=\"train\").select(range(25_000))\n",
    "flat_sentences = mnli[\"premise\"] + mnli[\"hypothesis\"]\n",
    "\n",
    "# Add noise to our input data\n",
    "damaged_data = DenoisingAutoEncoderDataset(list(set(flat_sentences)))\n",
    "\n",
    "# Create dataset\n",
    "train_dataset = {\"damaged_sentence\": [], \"original_sentence\": []}\n",
    "for data in tqdm(damaged_data):\n",
    "    train_dataset[\"damaged_sentence\"].append(data.texts[0])\n",
    "    train_dataset[\"original_sentence\"].append(data.texts[1])\n",
    "train_dataset = Dataset.from_dict(train_dataset)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "mymxiQ9A1eQm"
   },
   "outputs": [],
   "source": [
    "train_dataset[0]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "77IuQ8QIjO25"
   },
   "outputs": [],
   "source": [
    "# # Choose a different deletion ratio (0.6 is the default; change it to experiment)\n",
    "# flat_sentences = list(set(flat_sentences))\n",
    "# damaged_data = DenoisingAutoEncoderDataset(\n",
    "#     flat_sentences,\n",
    "#     noise_fn=lambda s: DenoisingAutoEncoderDataset.delete(s, del_ratio=0.6)\n",
    "# )"
   ]
  },
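  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "By default, `DenoisingAutoEncoderDataset` damages each sentence by deleting words at random (with a deletion ratio of 0.6), and TSDAE then learns to reconstruct the original sentence from the embedding of the damaged one. A simplified sketch of that noise, assuming whitespace tokenization instead of NLTK and a hypothetical `delete_words` helper; like the default noise function, it never returns an empty sentence:\n",
    "\n",
    "```python\n",
    "import random\n",
    "\n",
    "def delete_words(sentence, del_ratio=0.6, seed=None):\n",
    "    # Hypothetical helper: drop each whitespace-separated word with probability del_ratio\n",
    "    rng = random.Random(seed)\n",
    "    words = sentence.split()\n",
    "    if not words:\n",
    "        return sentence\n",
    "    kept = [w for w in words if rng.random() > del_ratio]\n",
    "    # Never return an empty sentence; keep one random word instead\n",
    "    if not kept:\n",
    "        kept = [rng.choice(words)]\n",
    "    return ' '.join(kept)\n",
    "\n",
    "original = 'the encoder must rebuild this sentence from a damaged copy'\n",
    "for seed in range(3):\n",
    "    print(delete_words(original, del_ratio=0.6, seed=seed))\n",
    "```\n",
    "\n",
    "Higher deletion ratios make reconstruction harder, which forces the sentence embedding to carry more of the meaning."
   ]
  },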
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Tl6CzdwNA1tC"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.evaluation import EmbeddingSimilarityEvaluator\n",
    "\n",
    "# Create an embedding similarity evaluator for stsb\n",
    "val_sts = load_dataset('glue', 'stsb', split='validation')\n",
    "evaluator = EmbeddingSimilarityEvaluator(\n",
    "    sentences1=val_sts[\"sentence1\"],\n",
    "    sentences2=val_sts[\"sentence2\"],\n",
    "    scores=[score/5 for score in val_sts[\"label\"]],\n",
    "    main_similarity=\"cosine\"\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "pYM298tWlacT"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers import models, SentenceTransformer\n",
    "\n",
    "# Create your embedding model\n",
    "word_embedding_model = models.Transformer('bert-base-uncased')\n",
    "pooling_model = models.Pooling(word_embedding_model.get_word_embedding_dimension(), 'cls')\n",
    "embedding_model = SentenceTransformer(modules=[word_embedding_model, pooling_model])"
   ]
  },
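  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`models.Pooling(..., 'cls')` uses the output vector of the first token (`[CLS]`) as the sentence embedding, whereas mean pooling averages over all non-padding tokens. A NumPy sketch contrasting the two on a fake token-embedding tensor that includes one padding position:\n",
    "\n",
    "```python\n",
    "import numpy as np\n",
    "\n",
    "# Fake token embeddings: 1 sentence, 4 tokens, 3 dimensions\n",
    "token_embeddings = np.array([[[1.0, 0.0, 0.0],   # [CLS]\n",
    "                              [0.0, 2.0, 0.0],\n",
    "                              [0.0, 0.0, 3.0],\n",
    "                              [9.0, 9.0, 9.0]]])  # padding token\n",
    "attention_mask = np.array([[1, 1, 1, 0]])\n",
    "\n",
    "# CLS pooling: take the first token's vector\n",
    "cls_embedding = token_embeddings[:, 0]\n",
    "\n",
    "# Mean pooling: average only over real (non-padding) tokens\n",
    "mask = attention_mask[..., None]\n",
    "mean_embedding = (token_embeddings * mask).sum(axis=1) / mask.sum(axis=1)\n",
    "\n",
    "print(cls_embedding)  # [[1. 0. 0.]]\n",
    "print(mean_embedding)\n",
    "```"
   ]
  },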
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "5RZ8tQFSlIHm"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers import losses\n",
    "\n",
    "# Use the denoising auto-encoder loss\n",
    "train_loss = losses.DenoisingAutoEncoderLoss(\n",
    "    embedding_model, tie_encoder_decoder=True\n",
    ")\n",
    "train_loss.decoder = train_loss.decoder.to(\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "PYApurOS07x0"
   },
   "outputs": [],
   "source": [
    "from sentence_transformers.trainer import SentenceTransformerTrainer\n",
    "from sentence_transformers.training_args import SentenceTransformerTrainingArguments\n",
    "\n",
    "# Define the training arguments\n",
    "args = SentenceTransformerTrainingArguments(\n",
    "    output_dir=\"tsdae_embedding_model\",\n",
    "    num_train_epochs=1,\n",
    "    per_device_train_batch_size=16,\n",
    "    per_device_eval_batch_size=16,\n",
    "    warmup_steps=100,\n",
    "    fp16=True,\n",
    "    eval_steps=100,  # only takes effect together with eval_strategy='steps'\n",
    "    logging_steps=100,\n",
    ")\n",
    "\n",
    "# Train model\n",
    "trainer = SentenceTransformerTrainer(\n",
    "    model=embedding_model,\n",
    "    args=args,\n",
    "    train_dataset=train_dataset,\n",
    "    loss=train_loss,\n",
    "    evaluator=evaluator\n",
    ")\n",
    "trainer.train()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nGxh6fTa7qIh"
   },
   "outputs": [],
   "source": [
    "# Evaluate our trained model\n",
    "evaluator(embedding_model)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "QAd7OKerm-w8"
   },
   "outputs": [],
   "source": [
    "import gc\n",
    "import torch\n",
    "\n",
    "gc.collect()\n",
    "torch.cuda.empty_cache()"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "authorship_tag": "ABX9TyMBa7Gqu8z7LCZ+HmiwqO6K",
   "gpuType": "T4",
   "provenance": [],
   "toc_visible": true
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.12.9"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
